April 22, 2025
\[ \newcommand\hbb{{\hat{\boldsymbol \beta}}} \newcommand\bb{{\boldsymbol \beta}} \newcommand\expn{{\frac{1}{N} \sum \limits_{i = 1}^N}} \newcommand\sumk{\sum \limits_{k = 1}^K} \newcommand\argminb{\underset{\bb}{\text{argmin }}} \newcommand\argmaxb{\underset{\bb}{\text{argmax }}} \newcommand\gtheta{\mathbf g(\boldsymbol \theta)} \newcommand\htheta{\mathbf H(\boldsymbol \theta)} \]
Model performance \(\neq\) success
Simplicity reigns supreme?
If a model gets too simple, we lose the ability to make meaningful predictions!
How can we make deep learning interpretable?
Who cares?
We all should.
Safety: Detect and prevent critical failures
Example: Autonomous driving misclassifications
Fairness: Reveal and mitigate biases to ensure equitable treatment across groups
Why does a prediction get made?
How can we determine how to fix problems if our data is inherently biased?
Example: Google Photo Tagging Problems
Privacy/Security: Ensuring that sensitive information in the data is protected
Given the richness of some models, is it possible to use them in an adversarial way to make them do bad things?
Is it possible to back-engineer private information?
Example: Tay the Chatbot
Legal: We have to
The GDPR’s “right to explanation”
People have the right to be given an explanation for an output of an algorithm that has been used to determine an outcome
Especially true for decisions that impact your life (credit scores, social scores, etc.)
Opponents say that this requirement will stifle innovation
Problem: Neural Networks are complicated
\[ \mathbf X \rightarrow g(\cdot) \rightarrow \hat{y} \]
The standard flat NN:
\[ \hat{y} \propto \mathbf W \varphi(\mathbf W \varphi (...)) \]
How can we interpret anything about how a decision was made?
We’re going to cover a number of approaches today!
A rapidly evolving area
A pretty good definition of interpretability (Biran and Cotton, 2017):
Interpretability is the degree to which a human can understand the cause of a decision
Another good one (Kim, Khanna, and Koyejo, 2016):
… a user can correctly and efficiently predict the method’s results.
To what extent can we do this with modern ML methods?
Best place to start is with a class of models that are inherently interpretable
The model can provide all of the info necessary for a human to replicate exactly how a prediction was made with pen and paper
What are some examples of glass-box models that you can think of?
Classic example: Linear Regression
\[ \hat{y} = g(\mathbf X \hat{\boldsymbol \beta}) \]
NOTE: We’re talking about causality in terms of the prediction. Not in terms of actual outcomes.
Changes in a value necessarily result in changes in the prediction
Doesn’t mean that the change will correspond with a change in the real-world outcome
This is an example of a glass-box model because the structure makes it inherently easy to see how things change as a function of each feature!
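A tiny numeric check of the glass-box property (the toy data and coefficients below are illustrative assumptions): in linear regression, shifting feature \(j\) by \(\delta\) shifts the prediction by exactly \(\hat\beta_j \delta\), pen-and-paper replicable.

```python
import numpy as np

# Toy glass-box demo: changing feature 0 by delta changes the
# prediction by exactly beta_hat[0] * delta.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=100)

# OLS fit via least squares
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

x = np.array([1.0, 1.0, 1.0])
x_shift = x.copy()
x_shift[0] += 0.5                      # perturb feature 0 by delta = 0.5

delta_pred = x_shift @ beta_hat - x @ beta_hat
# delta_pred equals 0.5 * beta_hat[0] by construction
```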
Other examples:
Simple decision trees
Ridge/LASSO
Generalized Additive Models
What is the generic problem with these predictive models?
A trade-off exists:
Any intuition as to why?
Interpretation requires functional specification
Functional specification = Assumptions
Modern machine learning requires models that can uncover the correct functional form from the data alone, given enough of it!
Modern ML requires the usage of universal approximators
Problem: These methods are typically local in nature!
Random Forests: Local rectangles
KNN: Local neighborhoods
Gaussian SVMs: Local distance structures
NNs: Local linear approximations
No one-size fits all method to interpret.
What do we do?
Two main approaches:
Model agnostic post-hoc interpretability: apply methods using the trained model to better understand how predictions are made
Local approaches - examine parts of a model to explain individual predictions (gradients, LIME)
Global approaches - examine the behavior of predictions on average using the prediction machine
Intrinsic Interpretability: change how the model is estimated to create windows that show how predictions are made
While methods exist for other ML algorithms, we’re going to completely concentrate on neural networks
Methods like XGBoost suffer from the same issue
One approach to interpretation is to only expect that we can explain predictions locally
Defense: the whole point is that the model is so complex that local explanations are really the only thing that matters
A complicated combination of credit score, income, etc. leads to mortgage approval
Explain individual predictions made by a black box!
Generic approach:
\(\hat{y} = g(\mathbf x)\)
I don’t know exactly what \(g()\) is, but I can quickly query it.
I want to know how changes in the feature values for my example result in changes to the prediction
Intuitive approach:
Let \(\mathbf x'\) be a feature vector that is close to \(\mathbf x\) associated with \(g(\mathbf x') = \hat{y}'\) where one value of \(\mathbf x\) is changed
If \(\hat{y}\) and \(\hat{y}'\) are really different, then that feature must have been important in making the prediction!
How can we quantify this argument and apply it to higher dimensional settings?
Let’s get more specific:
We have a target instance \(\mathbf x \in \mathbb R^P\); generate perturbed samples \(\mathbf x_i\) near it, weight each by its proximity \(w_i\), and fit a weighted, sparse linear surrogate to the black box’s predictions \(\hat{y}_i\):
\[ \hat{\boldsymbol \beta} = \underset{\{\beta_0 , \boldsymbol \beta \}}{\text{argmin }} \sum \limits_{i = 1}^N w_i(\hat{y}_i - \beta_0 - \mathbf x_i' \boldsymbol \beta)^2 + \lambda \sum \limits_{j = 1}^P | \beta_j | \]
Local Interpretable Model-Agnostic Explanations is an approach that creates a local linear approximation to the predictive function in the neighborhood of the point of interest
The corresponding coefficients to this weighted LASSO correspond to the loss minimizing importance function conditional on \(\lambda\)
Any thoughts on the intuition of this?
Ultimately, just trying to visualize the predictive surface in the neighborhood around the point!
Works in any scenario where we can generate a neighborhood of points around the example.
A little tricky for images and text!
For any model estimated using PyTorch, we can use the captum library in Python to run a LIME model!
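The weighted-LASSO recipe above can also be sketched from scratch in a few lines. The black box \(g\), the neighborhood scale, and the kernel width below are illustrative assumptions, not any library’s API:

```python
import numpy as np
from sklearn.linear_model import Lasso

# Pretend black box: linear in feature 0, nonlinear in feature 1
def g(X):
    return 3.0 * X[:, 0] - 0.5 * X[:, 1] ** 2

rng = np.random.default_rng(0)
x0 = np.array([1.0, 1.0])                        # target instance
Z = x0 + rng.normal(scale=0.3, size=(500, 2))    # perturbed neighborhood
y_hat = g(Z)                                     # query the black box

# Proximity weights: closer perturbations count more
d2 = ((Z - x0) ** 2).sum(axis=1)
w = np.exp(-d2 / 0.25)

# Weighted LASSO = the local linear surrogate
surrogate = Lasso(alpha=0.01)
surrogate.fit(Z, y_hat, sample_weight=w)
# surrogate.coef_ approximates the local slopes of g at x0
```

The coefficients recover the local slopes of \(g\) at \(\mathbf x_0\): roughly 3 for feature 0 and \(-x_1 = -1\) for feature 1, shrunk slightly by the LASSO penalty.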
Strengths:
Explanations are short and easy to understand
Easy to implement
Weaknesses:
Very sensitive to neighborhood choice
“Noising” method can lead to really unlikely data points due to strong dependencies between features
Can only be applied one point at a time
What if I want a more global measure of feature importance?
How sensitive is this prediction to small perturbations?
How much does this feature contribute to explaining the predictions for all data in my data set?
Local sensitivity is important for explaining individual predictions
Not as important in the case where we want to think about feature importance more in the regression sense
Think LASSO
Another way to think about feature importance
Including this feature significantly reduces the overall loss of the predictive function
Doesn’t answer the question of how a feature contributes to a prediction
Does answer the question of how much a feature contributes to a prediction
How could we assess this globally in a model agnostic way?
One approach: Leave One Feature Out fitting
Problem: Takes way too long!!!!
Slightly better approach: permutation importance
Important global model agnostic method!
Can (and should) be applied to any black box model on tabular data!
The default method for RFs, XGBoost, and Tabular NNs
Problem: Doesn’t really apply to images and text.
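On tabular data, permutation importance is a short loop: shuffle one column at a time and measure how much the loss degrades. The black box below is a stand-in for any fitted model:

```python
import numpy as np

# Stand-in black box: feature 0 matters much more than feature 1
def black_box(X):
    return 4.0 * X[:, 0] + 0.1 * X[:, 1]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = black_box(X) + rng.normal(scale=0.1, size=1000)

def mse(X):
    return np.mean((y - black_box(X)) ** 2)

base = mse(X)
importance = []
for j in range(X.shape[1]):
    Xp = X.copy()
    Xp[:, j] = rng.permutation(Xp[:, j])   # break the feature-target link
    importance.append(mse(Xp) - base)      # loss increase = importance
# importance[0] dwarfs importance[1]
```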
LIME and Permutations are good for all models
More complex inputs requires more complex methods!
Let’s look at an approach for image classification.
What makes this a picture of a dog?
Is there an agnostic way that we can know what pixels correspond most heavily with a certain class label?
For image classification problems, the final layer of the CNN produces scores for each image (e.g. the predictions)
\[ \mathbf s = (s_1,s_2,...,s_C) \]
By construction, deep learning models are fully differentiable, and therefore fully backprop-able
This means that for any image, \(\mathbf x_i\), we can compute:
\[ \frac{\partial s_j}{\partial \mathbf x_i} \]
Since the gradient of a scalar w.r.t. a tensor-valued input has the same shape as that input, the gradient of the score is a pixel-by-pixel map of how much a change in each pixel value (taking into account the convolutional structure) changes the score!
The Vanilla Gradient method speeds up the process of creating LIME-style maps for images by using gradients!
Via a first-order Taylor series expansion:
\[ s_c(\mathbf x) \approx \mathbf w^T \mathbf x + b \]
where:
\[ \mathbf w = \frac{\partial s_c}{\partial \mathbf x}\rvert_{\mathbf x_0} \]
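In PyTorch, this gradient is one backward call away. The tiny untrained CNN below is a placeholder for a real classifier; only the autograd mechanics matter here:

```python
import torch

# Placeholder model: an untrained CNN producing 10 class scores
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 16 * 16, 10),
)
model.eval()

x = torch.randn(1, 3, 16, 16, requires_grad=True)  # "image" input
scores = model(x)
target_class = scores.argmax()
scores[0, target_class].backward()                 # d(score_c)/d(pixels)

# Saliency: absolute gradient, max over color channels per pixel
saliency = x.grad.abs().max(dim=1).values
# saliency has one value per pixel: shape (1, 16, 16)
```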
Vanilla Gradients:
Problem: Vanilla gradients tend to be really sensitive and noisy
Reason: Derivatives fluctuate greatly at small scales
No incentive to smooth gradients for a model that fits the training data well
Overfit w.r.t. billions of parameters, but works well when put together
Solution: SmoothGrad
Adding a little noise to see where gradients are randomly fluctuating will help to smooth the process!
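The averaging step can be demonstrated on a toy "gradient" (an illustrative assumption, not a real network): a smooth component plus a high-frequency wiggle standing in for pixel-level gradient noise. SmoothGrad averages the gradient over jittered copies of the input, and the wiggle washes out:

```python
import numpy as np

# Toy gradient: smooth part cos(x) plus a high-frequency wiggle
# that mimics noisy pixel-level saliency
def noisy_grad(x):
    return np.cos(x) + 0.5 * np.sin(60.0 * x)

rng = np.random.default_rng(0)
x = np.array([0.5, 1.0, 1.5])

# SmoothGrad: average the gradient over noise-perturbed inputs
n, sigma = 200, 0.1
smooth = np.mean(
    [noisy_grad(x + rng.normal(scale=sigma, size=x.shape)) for _ in range(n)],
    axis=0,
)
# the wiggle averages out; what survives is close to cos(x)
```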
Vanilla Gradients work decently, but can be jumpy due to the pixel level feature importance mapping
An alternative approach attempts to use the feature maps present in CNNs to apply strict spatial semantic continuity in the salience map!
The final layer before the fully connected classification head carries a lower dimensional, denoised representation of the original image!
Alter the approach to only compute the gradient at the final convolutional layer and upscale that back to the original size of the image.
Gradient Weighted Class Activation Mappings (Grad-CAM) does the following:
Compute the class scores for an input image
Assume that the final convolutional layer yields a tensor of size \(k \times k \times D\). For each of the \(D\) feature maps, compute the gradient of the score w.r.t. that map.
Across all \(D\) filters, globally average pool the “pixels” in the gradient. This finds which of the \(D\) filters have a large influence on the class label. This yields a filter weight, \(\alpha_d\)
Create the coarse heatmap
\[ M_{ij} = \text{ReLU}\left(\sum \limits_{d = 1}^D \alpha_d A_{i,j,d} \right) \] where \(A\) is the final convolutional layer’s activation tensor
Bilinearly upscale the heatmap to the original image size
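The Grad-CAM arithmetic is short once the activations and gradients are in hand. Below they are faked with random arrays (in practice both come from a forward/backward pass), and a nearest-neighbor upscale stands in for bilinear interpolation:

```python
import numpy as np

# Faked final-conv activations A and score gradients G (stand-ins
# for a real forward/backward pass through the network)
rng = np.random.default_rng(0)
k, D = 7, 64
A = rng.normal(size=(k, k, D))           # activations
G = rng.normal(size=(k, k, D))           # d(score_c)/dA

alpha = G.mean(axis=(0, 1))              # global-average-pool the gradients
heatmap = np.maximum(0.0, (A * alpha).sum(axis=2))  # ReLU(sum_d alpha_d A_d)

# Nearest-neighbor upscale back to a 224x224 "image" (bilinear in practice)
upscaled = np.kron(heatmap, np.ones((32, 32)))
# upscaled.shape == (224, 224); heatmap is non-negative by the ReLU
```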
This approach is often too coarse, so guide the approach by combining Vanilla Gradients with Grad-CAM to get guided Grad-CAM.
Simple fix:
Works decently well to get pixel level maps that explain why an image was classified as a dog!
Saliency maps work well for images. What can we do for text?
Broadly, the only models worth thinking about here are those that rely on self-attention
Let’s think about how we might be able to do something similar for text classification using BERT
Goal: Given an input sentence, \(\mathbf x\), understand how words relate to one another and ultimately lead to a particular classification!
Solution: Check out the attention weights w.r.t. the classification token
Problem: Multiheaded attention means that there are multiple attention heads in each layer
Problem: Multiple layers of attention - base BERT has 12!!!!!
\[ \begin{array}{c|ccccccc} & \mathrm{[CLS]} & I & \mathrm{loved} & \mathrm{this} & \mathrm{movie} & ! & \mathrm{[SEP]} \\ \hline \mathrm{[CLS]} & 0.05 & 0.03 & 0.30 & 0.05 & 0.40 & 0.05 & 0.12 \\ I & 0.01 & 0.10 & 0.75 & 0.05 & 0.03 & 0.03 & 0.03 \\ \mathrm{loved} & 0.02 & 0.01 & 0.05 & 0.60 & 0.25 & 0.04 & 0.03 \\ \mathrm{this} & 0.01 & 0.02 & 0.03 & 0.10 & 0.80 & 0.02 & 0.02 \\ \mathrm{movie} & 0.02 & 0.01 & 0.02 & 0.10 & 0.05 & 0.80 & 0.00 \\ ! & 0.05 & 0.05 & 0.05 & 0.05 & 0.05 & 0.05 & 0.70 \\ \mathrm{[SEP]} & 0.80 & 0.04 & 0.04 & 0.04 & 0.04 & 0.02 & 0.02 \\ \end{array} \]
Each head and layer is associated with its own attention matrix.
Instead, aggregate in some meaningful way.
Aggregate within layer by averaging attention weight over all heads
Aggregate across model by averaging layer averages over all layers
Lose clarity!
Note: At each aggregation step, renormalize so that rows add up to 1!
More clever aggregation via roll-out
Transformers add a residual (identity) connection around each attention block. We can mimic this by defining an augmented layer-wise attention matrix (averaged over layer):
\[ \tilde{\mathbf A}^{(\ell)} = \mathbf A^{(\ell)} + \mathbf I \]
Why add identity?
Each token actually is allowed to attend to itself via a skip path. The raw layer attention matrices don’t take this connection into account.
Keeps the attention model from collapsing on itself.
After re-normalization (approximating what a transformer does), we can define a rolled-out version of attention weights that tracks attention across the entire stack of self-attention layers as:
\[ \mathbf R = \hat{\mathbf A}^{(1)} \times \hat{\mathbf A}^{(2)} \times ... \times \hat{\mathbf A}^{(L)} \]
\(\mathbf R\) tells us the total mass of attention flowing from token \(i\) in the input all the way through the network to token \(j\)
High values mean that there is a lot of “attention” given to \(j\) by \(i\) in the network
BERT’s class token can then be used to track attention to the class label all the way through!
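Roll-out is a short matrix loop. The per-layer matrices below (already averaged over heads) are faked with random row-stochastic matrices; the shapes mirror base BERT’s 12 layers over a 7-token sentence:

```python
import numpy as np

# Fake head-averaged attention matrices: L layers, T tokens each,
# rows normalized to sum to 1 like real attention weights
rng = np.random.default_rng(0)
T, L = 7, 12
layers = rng.random(size=(L, T, T))
layers /= layers.sum(axis=2, keepdims=True)

R = np.eye(T)
for A in layers:
    A_tilde = A + np.eye(T)                         # add the residual path
    A_tilde /= A_tilde.sum(axis=1, keepdims=True)   # re-normalize rows
    R = R @ A_tilde                                 # accumulate across layers

# Row i of R: total attention mass token i ultimately sends to each
# input token; rows remain a valid distribution (sum to 1)
```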
Cost: Aggregation loses specificity!
Small local patterns can be lost due to aggregation
If only our eyes could reasonably see in 120 dimensions at once…
Causal importance (in the predictive sense) is not directly encoded here
Too aggregated generally
But, can be useful for locating general patterns
All of these methods are post-hoc importance metrics.
Train the model
Apply a method
Tease out importance from the model
What do you think is the big weakness of this approach?
The clarity of explanations from regression models comes from their construction!
Clarity is a part of the model structure and loss functions
Reward clarity in model training
Think about house-training a puppy
When the puppy doesn’t pee in the house, we want to reward it
Two strategies:
Give the dog treats at the end of the day when it doesn’t pee in the house
Give the dog treats immediately after it pees outside. Give light negative reinforcement when the puppy pees in the house.
Which would work better?
Ultimately, we can’t expect a model to learn interpretable structures if we don’t tell it to do so!
Regression approaches:
Add L1 penalties on the weights
Simplify model form
Why won’t these work for modern deep learning architectures?
Clever change 1: Mixtures of Experts for FFNNs
Consider a flat layer in a neural network with 4000 hidden units
We follow up each layer with a ReLU activation function.
Remember that a weight (or hidden representation) of zero is equivalent to not including it in the model at all
Can we restructure our NN architecture to impose this sparsity constraint in a nicer way?
Consider dividing the 4000 hidden units into 10 sets of 400.
At full capacity, switch on all 10 blocks at once.
If we train the model right, though, each block of 400 units can be encouraged to build expertise for a particular type of input!
Still allow functional variation to exist
But, make variation for larger scale concepts fall into one or more of the blocks!
General strategy:
Rationale for the Mixture of Experts approach:
Significant decrease in training time for insanely large models
When backpropping for a specific training example, each gradient w.r.t. a layer corresponds only to elements that are not fixed
Consider a MoE that chooses the top 2 experts for each example
Total operations for full model: 4000 x N
Total operations for MoE: (2 x 400) x N, plus the (cheap) router network
Same number of overall parameters with less train time needed!
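A top-2 router can be sketched in a few lines. The random router and expert matrices below stand in for trained weights; only the routing logic matters:

```python
import numpy as np

# Random stand-ins for trained expert and router weights
rng = np.random.default_rng(0)
d_in, d_out, n_experts = 32, 16, 10
experts = rng.normal(size=(n_experts, d_in, d_out)) * 0.1
router = rng.normal(size=(d_in, n_experts)) * 0.1

def moe_forward(x, top_k=2):
    logits = x @ router
    top = np.argsort(logits)[-top_k:]          # indices of the top-k experts
    gates = np.exp(logits[top] - logits[top].max())
    gates /= gates.sum()                       # softmax over chosen experts
    # only top_k of the n_experts weight blocks touch this example
    return sum(g * (x @ experts[i]) for g, i in zip(gates, top))

x = rng.normal(size=d_in)
y = moe_forward(x)
# y.shape == (16,) but only 2 of the 10 experts did any work
```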
Through specialization, we get interpretability and scalability together
Important point: these two things don’t have to be separate!
If we design it to be interpretable, it’ll be interpretable.
Just need more clever design strategies.
Concept bottleneck models
KNN Retrieval Models
At inference time, retrieve the top-k most similar examples from a stored memory of training data in the encoder space and aggregate their labels (or representations) to make a prediction
Sorta like a deep KNN model
Near examples = some amount of interpretability!
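The retrieval step can be sketched as follows. The "encoder" here is the identity (an illustrative simplification; in practice it would be the trained network’s penultimate layer), and the memory is two synthetic clusters:

```python
import numpy as np

# Memory of "encoded" training examples: two labeled clusters
rng = np.random.default_rng(0)
mem_x = np.vstack([rng.normal(0.0, 0.3, size=(50, 4)),   # class 0 cluster
                   rng.normal(3.0, 0.3, size=(50, 4))])  # class 1 cluster
mem_y = np.array([0] * 50 + [1] * 50)

def predict(x, k=5):
    d = np.linalg.norm(mem_x - x, axis=1)   # distance in encoder space
    neighbors = np.argsort(d)[:k]           # the k nearest training points
    # the retrieved neighbors double as the explanation for the prediction
    return np.bincount(mem_y[neighbors]).argmax(), neighbors

label, evidence = predict(np.full(4, 2.9))
# a query near the class-1 cluster is labeled 1, with 5 exemplars as evidence
```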
This is the way forward to ensuring that modern AI is human-interpretable
Gotta teach the puppy not to pee in the house!
It won’t just learn it on its own.
Big money in this area right now.